package com.example.websocket;

import com.example.service.WebSocketService;
import com.fasterxml.jackson.databind.ObjectMapper;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import org.springframework.web.socket.*;

import java.io.IOException;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/**
 * WebSocket处理器
 * 处理WebSocket连接、消息发送和接收
 */
@Slf4j
@Component
public class WebSocketHandler implements org.springframework.web.socket.WebSocketHandler {

    @Autowired
    private WebSocketService webSocketService;

    @Autowired
    private ObjectMapper objectMapper;

    // 存储所有WebSocket会话
    private final Map<String, WebSocketSession> sessions = new ConcurrentHashMap<>();

    @Override
    public void afterConnectionEstablished(WebSocketSession session) throws Exception {
        String userId = getUserId(session);
        if (userId != null) {
            sessions.put(userId, session);
            webSocketService.addSession(userId, session);
            log.info("WebSocket连接建立: userId={}, sessionId={}", userId, session.getId());
            
            // 发送连接成功消息
            sendMessage(session, Map.of(
                "type", "system",
                "content", "连接成功",
                "timestamp", System.currentTimeMillis()
            ));
        } else {
            log.warn("WebSocket连接失败: 无法获取用户ID, sessionId={}", session.getId());
            session.close();
        }
    }

    @Override
    public void handleMessage(WebSocketSession session, WebSocketMessage<?> message) throws Exception {
        String userId = getUserId(session);
        if (userId == null) {
            return;
        }

        try {
            String payload = message.getPayload().toString();
            Map<String, Object> messageData = objectMapper.readValue(payload, Map.class);
            String messageType = (String) messageData.get("type");

            log.debug("收到WebSocket消息: userId={}, type={}, message={}", userId, messageType, payload);

            switch (messageType) {
                case "heartbeat":
                    handleHeartbeat(session, userId);
                    break;
                case "chat":
                    handleChatMessage(session, userId, messageData);
                    break;
                case "notification":
                    handleNotificationMessage(session, userId, messageData);
                    break;
                default:
                    log.warn("未知消息类型: {}", messageType);
            }
        } catch (Exception e) {
            log.error("处理WebSocket消息失败: userId={}, error={}", userId, e.getMessage(), e);
        }
    }

    @Override
    public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception {
        String userId = getUserId(session);
        log.error("WebSocket传输错误: userId={}, sessionId={}, error={}", 
                userId, session.getId(), exception.getMessage(), exception);
    }

    @Override
    public void afterConnectionClosed(WebSocketSession session, CloseStatus closeStatus) throws Exception {
        String userId = getUserId(session);
        if (userId != null) {
            sessions.remove(userId);
            webSocketService.removeSession(userId);
            log.info("WebSocket连接关闭: userId={}, sessionId={}, status={}", 
                    userId, session.getId(), closeStatus);
        }
    }

    @Override
    public boolean supportsPartialMessages() {
        return false;
    }

    /**
     * 处理心跳消息
     */
    private void handleHeartbeat(WebSocketSession session, String userId) {
        try {
            sendMessage(session, Map.of(
                "type", "heartbeat",
                "timestamp", System.currentTimeMillis()
            ));
        } catch (Exception e) {
            log.error("发送心跳响应失败: userId={}, error={}", userId, e.getMessage());
        }
    }

    /**
     * 处理聊天消息
     */
    private void handleChatMessage(WebSocketSession session, String userId, Map<String, Object> messageData) {
        // 转发聊天消息给目标用户
        String targetUserId = (String) messageData.get("targetUserId");
        if (targetUserId != null) {
            webSocketService.sendMessageToUser(targetUserId, messageData);
        }
    }

    /**
     * 处理通知消息
     */
    private void handleNotificationMessage(WebSocketSession session, String userId, Map<String, Object> messageData) {
        // 处理通知消息的逻辑
        log.info("处理通知消息: userId={}, message={}", userId, messageData);
    }

    /**
     * 发送消息到指定会话
     */
    private void sendMessage(WebSocketSession session, Object message) throws IOException {
        if (session.isOpen()) {
            String jsonMessage = objectMapper.writeValueAsString(message);
            session.sendMessage(new TextMessage(jsonMessage));
        }
    }

    /**
     * 从会话中获取用户ID
     */
    private String getUserId(WebSocketSession session) {
        return (String) session.getAttributes().get("userId");
    }

    /**
     * 获取所有活跃会话
     */
    public Map<String, WebSocketSession> getSessions() {
        return sessions;
    }
}
